import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm

# --- Load data ---
# Input workbooks (paths relative to the working directory).
cost_file = 'CAPEX DATA.xlsb'
scenario_file = 'Scenario.xlsx'

# Historical PEM CAPEX: keep the 'Year' and 'PEM' columns, drop incomplete rows,
# rename PEM -> Cost_PEM, coerce everything to numbers and drop any row that
# failed the numeric conversion.
df_cost = (
    pd.read_excel(cost_file, sheet_name='CAPEX', engine='pyxlsb')[['Year', 'PEM']]
    .dropna()
    .rename(columns={'PEM': 'Cost_PEM'})
    .apply(pd.to_numeric, errors='coerce')
    .dropna()
)

# Deployment scenarios from sheet 'S2' (first spreadsheet row skipped). Columns
# are labelled by position, then rows are cleaned the same way as the cost data.
df_scenario = pd.read_excel(scenario_file, sheet_name='S2', skiprows=1)
df_scenario.columns = ['Year', 'S1', 'S2', 'S3']
df_scenario = df_scenario.dropna()
df_scenario = df_scenario.apply(pd.to_numeric, errors='coerce').dropna()

# --- Regression using S3 as anchor ---
# Fit the learning curve in log-log form,
#     log(Cost_PEM) = const + b * log(Deployment_PEM),
# by OLS over the years present in both the cost history and scenario S3.
df_forecast_s3 = df_scenario[['Year', 'S3']].rename(columns={'S3': 'Deployment_PEM'})
df_merged = pd.merge(df_cost, df_forecast_s3, on='Year').sort_values('Year').reset_index(drop=True)

# np.log is only finite for strictly positive inputs: a zero or negative cost or
# deployment would become -inf/NaN and break the OLS fit, so exclude those years.
df_merged = df_merged[
    (df_merged['Cost_PEM'] > 0) & (df_merged['Deployment_PEM'] > 0)
].reset_index(drop=True)
if len(df_merged) < 2:
    raise ValueError(
        "Need at least two years with positive PEM cost and S3 deployment to fit "
        f"the learning curve, found {len(df_merged)}."
    )

df_merged['log_cost'] = np.log(df_merged['Cost_PEM'])
df_merged['log_deployment'] = np.log(df_merged['Deployment_PEM'])

X = sm.add_constant(df_merged['log_deployment'])
y = df_merged['log_cost']
model = sm.OLS(y, X).fit()

# Learning exponent: slope of log cost on log deployment.
b_pem = model.params['log_deployment']
# Sample standard deviation (pandas default ddof=1) of the in-sample log residuals.
# NOTE(review): this is the spread of level residuals, yet the simulation adds an
# independent shock of this size at every forecast step, so the implied spread
# widens with the horizon -- confirm that is the intended uncertainty model.
sigma_pem = model.resid.std()

# --- Wright's Law stochastic simulation ---
def simulate_wrights_law(log_cost_0, log_deploy_diff, b, sigma, n_sim=1000, rng=None):
    """Simulate Monte Carlo cost paths under a stochastic Wright's law.

    Every path starts at ``log_cost_0``. At step ``t`` it adds
    ``b * log_deploy_diff[t]`` plus an independent ``N(0, sigma)`` shock, so the
    log cost follows a random walk with drift given by the deployment growth.

    Parameters
    ----------
    log_cost_0 : float
        Log of the starting cost (the anchor year).
    log_deploy_diff : array_like of float
        Change in log deployment for each forecast step, in chronological order.
    b : float
        Learning exponent (slope of log cost on log deployment).
    sigma : float
        Standard deviation of the per-step log-cost shock. Must be >= 0.
    n_sim : int, optional
        Number of simulated paths.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the normal shocks. When omitted, the global ``np.random``
        state is used, so ``np.random.seed`` still controls the output.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_sim, len(log_deploy_diff))``. Entry ``[i, t]`` is the cost
        (not the log cost) of path ``i`` after step ``t``.
    """
    steps = np.asarray(log_deploy_diff, dtype=float)
    normal = np.random.normal if rng is None else rng.normal
    # Draw every shock at once: one row per path, one column per step
    # (same row-major order as drawing them one at a time path by path).
    shocks = normal(0, sigma, size=(n_sim, len(steps)))
    # A random walk in log space is the running sum of its increments.
    log_paths = log_cost_0 + np.cumsum(b * steps + shocks, axis=1)
    return np.exp(log_paths)

# --- Forecast function ---
def prepare_forecast_for_scenario(scenario_label, df_scenario, df_cost, b, sigma, n_sim=1000):
    """Simulate PEM CAPEX paths for one deployment scenario and summarise them.

    The simulation is anchored at the most recent year that appears in both
    ``df_cost`` and the scenario. It then runs over every scenario year strictly
    after that anchor, in chronological order.

    Parameters
    ----------
    scenario_label : str
        Column of ``df_scenario`` holding the deployment series (e.g. ``'S1'``).
    df_scenario : pandas.DataFrame
        Must contain ``'Year'`` and ``scenario_label`` columns.
    df_cost : pandas.DataFrame
        Must contain ``'Year'`` and ``'Cost_PEM'`` columns.
    b : float
        Learning exponent passed to ``simulate_wrights_law``.
    sigma : float
        Per-step log-cost shock standard deviation passed to ``simulate_wrights_law``.
    n_sim : int, optional
        Number of Monte Carlo paths.

    Returns
    -------
    dict
        ``'years'``: the anchor year followed by the forecast years.
        ``'median'``: the anchor cost followed by the per-year median simulated cost.

    Raises
    ------
    ValueError
        If the cost data and the scenario share no year, or if the anchor cost or
        any deployment value used in the forecast is not strictly positive (its
        logarithm would be undefined).
    """
    df_forecast = (
        df_scenario[['Year', scenario_label]]
        .rename(columns={scenario_label: 'Deployment_PEM'})
        .sort_values('Year')  # np.diff below assumes chronological order
    )
    df_merged = pd.merge(df_cost, df_forecast, on='Year').sort_values('Year').reset_index(drop=True)
    if df_merged.empty:
        raise ValueError(
            f"Cost data and scenario {scenario_label!r} have no year in common; "
            "cannot anchor the forecast."
        )

    # Anchor: the latest year with both an observed cost and a scenario deployment.
    year_anchor = df_merged['Year'].max()
    is_anchor = df_merged['Year'] == year_anchor
    cost_anchor = df_merged.loc[is_anchor, 'Cost_PEM'].values[0]
    deploy_anchor = df_merged.loc[is_anchor, 'Deployment_PEM'].values[0]

    # Forecast strictly after the anchor year rather than from a fixed calendar
    # year, so the anchor is never repeated (a zero-growth step) and no year
    # between the anchor and the forecast start is skipped.
    df_future = df_forecast[df_forecast['Year'] > year_anchor]
    deploy_forecast = df_future['Deployment_PEM'].values
    years_forecast = df_future['Year'].values

    if cost_anchor <= 0 or deploy_anchor <= 0 or (deploy_forecast <= 0).any():
        raise ValueError(
            f"Scenario {scenario_label!r}: cost and deployment must be strictly "
            "positive to take logarithms for Wright's law."
        )

    log_cost_0 = np.log(cost_anchor)
    # Step-by-step change in log deployment, starting from the anchor year.
    log_deploy_path = np.log(np.insert(deploy_forecast.astype(float), 0, deploy_anchor))
    log_deploy_diff = np.diff(log_deploy_path)

    cost_paths = simulate_wrights_law(log_cost_0, log_deploy_diff, b, sigma, n_sim=n_sim)

    # Prepend the observed anchor point so the plotted curve starts from real data.
    return {
        'years': np.insert(years_forecast, 0, year_anchor),
        'median': np.insert(np.median(cost_paths, axis=0), 0, cost_anchor),
    }

# --- Forecasts for all scenarios ---
# One median forecast per deployment scenario, all sharing the learning-curve
# parameters fitted on S3. Evaluated in the order S1, S2, S3.
forecast_s1, forecast_s2, forecast_s3 = [
    prepare_forecast_for_scenario(label, df_scenario, df_cost, b_pem, sigma_pem)
    for label in ('S1', 'S2', 'S3')
]

# --- Final median plot ---
plt.figure(figsize=(8, 5))

# (forecast, legend label, line colour), drawn in this order.
for _forecast, _legend, _colour in (
    (forecast_s1, 'S1 Median', 'steelblue'),
    (forecast_s2, 'S2 Median', 'seagreen'),
    (forecast_s3, 'S3 Median', 'darkorange'),
):
    plt.plot(
        _forecast['years'],
        _forecast['median'],
        label=_legend,
        linestyle='--',
        marker='o',
        color=_colour,
    )

plt.xlabel('Year', fontsize=14)
plt.ylabel('CAPEX (EUR/kW)', fontsize=14)
plt.title('Median CAPEX Forecast for PEM Electrolysers', fontsize=15)
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend()
plt.tight_layout()
plt.show()
